{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "2bhiZKFiIWbK"
      },
      "source": [
        "# Non-Federated EMNIST Baseline Training\n",
        "\n",
        "This colab has three main parts:\n",
        "\n",
        "*   It trains a non-federated model on a flattened and shuffled (that is,\n",
        "    non-federated) view of the the\n",
        "    [Federated EMNIST](https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist/load_data)\n",
        "    dataset. The model architecture matches the simple CNN from the paper\n",
        "    [Communication-Efficient Learning of Deep Networks from Decentralized Data](https://arxiv.org/abs/1602.05629).\n",
        "    This (currently untuned) training reaches an accuracy of about 97% with\n",
        "    vanilla SGD. (Note these accuracy numbers are not directly comparable to\n",
        "    MNIST results, as the train and test datasets are different). This is\n",
        "    intended to serve as a baseline for simulated federated training on the Fed\n",
        "    EMNIST dataset.\n",
        "\n",
        "*   It uses this model to examine the Fed EMNIST dataset, showing it has\n",
        "    interesting variation across users.\n",
        "\n",
        "*   As a sanity check, it shows an equivalent model can be trained using the\n",
        "    `Federated Averaging` implementation from `tff.learning` applied to a\n",
        "    non-federated (that is, flattened and shuffled) view of the data.\n",
        "\n",
        "\n",
        "**Note:** This notebook will probably take ~25 minutes to fully execute."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "CZ2s96PebCkJ"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_federated\n",
        "!pip install tensorflow_gan"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "h2CQ4u1H0I_B"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function\n",
        "\n",
        "import collections\n",
        "import functools\n",
        "import numpy as np\n",
        "import time\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import pandas as pd\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_federated as tff\n",
        "import tensorflow_gan as tfgan\n",
        "\n",
        "tf.compat.v1.enable_v2_behavior()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "w3q3XlRja0tn"
      },
      "source": [
        "# Training a baseline model with Keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sPSSgR7kRv0z"
      },
      "source": [
        "## Data\n",
        "\n",
        "Download the data, and lightly reformat for use in Keras."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "lr7O_KnZR1jW"
      },
      "outputs": [],
      "source": [
        "emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "I9HmvLs_R7dK"
      },
      "outputs": [],
      "source": [
        "Example = collections.namedtuple('Example', ['x', 'y'])\n",
        "\n",
        "BATCH_SIZE = 10\n",
        "SHUFFLE_BUFFER = 10000\n",
        "\n",
        "\n",
        "def element_fn(element):\n",
        "  return Example(\n",
        "      x=tf.reshape(element['pixels'], [-1]),\n",
        "      y=tf.reshape(element['label'], [1]))\n",
        "\n",
        "\n",
        "def preprocess_train(dataset, batch_size=BATCH_SIZE):\n",
        "  return dataset.map(element_fn).apply(\n",
        "      tf.data.experimental.shuffle_and_repeat(\n",
        "          buffer_size=SHUFFLE_BUFFER, count=-1)).batch(batch_size)\n",
        "\n",
        "\n",
        "def preprocess_test(dataset):\n",
        "  return dataset.map(element_fn).batch(100, drop_remainder=False)\n",
        "\n",
        "\n",
        "flat_train_data = preprocess_train(\n",
        "    emnist_train.create_tf_dataset_from_all_clients(seed=739613565))\n",
        "flat_test_data = preprocess_test(\n",
        "    emnist_test.create_tf_dataset_from_all_clients(seed=686991103))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Sg31X5CLRu4m"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "K7MG76AhjxxL"
      },
      "outputs": [],
      "source": [
        "def build_cnn():\n",
        "  \"\"\"The CNN model used in https://arxiv.org/abs/1602.05629.\n",
        "\n",
        "  The number of parameters (1,663,370) matches what is reported in the paper.\n",
        "  \"\"\"\n",
        "  data_format = 'channels_last'\n",
        "  input_shape = [28, 28, 1]\n",
        "\n",
        "  # Alternatively:\n",
        "  # data_format = 'channels_first'\n",
        "  # input_shape = [1, 28, 28]\n",
        "\n",
        "  max_pool = lambda: tf.keras.layers.MaxPooling2D(\n",
        "      pool_size=(2, 2), padding='same', data_format=data_format)\n",
        "  conv2d = functools.partial(\n",
        "      tf.keras.layers.Conv2D,\n",
        "      kernel_size=5,\n",
        "      padding='same',\n",
        "      data_format=data_format,\n",
        "      activation=tf.nn.relu)\n",
        "\n",
        "  model = tf.keras.models.Sequential([\n",
        "      tf.keras.layers.Reshape(target_shape=input_shape, input_shape=(28 * 28,)),\n",
        "      conv2d(filters=32),\n",
        "      max_pool(),\n",
        "      conv2d(filters=64),\n",
        "      max_pool(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(512, activation=tf.nn.relu),\n",
        "      tf.keras.layers.Dense(10, activation=tf.nn.softmax),\n",
        "  ])\n",
        "\n",
        "  model.compile(\n",
        "      loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
        "      # This learning rate has not been tuned.\n",
        "      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),\n",
        "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "  return model\n",
        "\n",
        "\n",
        "build_cnn().summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JCBszLXUTzTP"
      },
      "source": [
        "## Training and evaluation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "W1QJ4WPosLGC"
      },
      "outputs": [],
      "source": [
        "NUM_ROUNDS = 10\n",
        "BATCHES_PER_ROUND = 1000"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "jlq0KEy3RndT"
      },
      "outputs": [],
      "source": [
        "model = build_cnn()\n",
        "# We set steps_per_epoch and epochs just to break training up in reasonable \"chunks\".\n",
        "# These aren't really epochs over the full dataset.\n",
        "# Training could take about 8 minutes.\n",
        "model.fit(flat_train_data, steps_per_epoch=BATCHES_PER_ROUND, epochs=NUM_ROUNDS)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3eV1He4dfuRf"
      },
      "source": [
        "There are 40,832 test examples, so we take 409 eval steps with a test batch size\n",
        "of 100:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "bB1g32bFRpSC"
      },
      "outputs": [],
      "source": [
        "_ = model.evaluate(flat_test_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "_dwWV0jqy-Y-"
      },
      "source": [
        "# Aside: Do some users have hard-to-classify data?\n",
        "\n",
        "Since this is a public dataset intended for research, one interesting thing we\n",
        "can do with the model is to use it to see if some users\n",
        "have hard-to-classify data. This is way of verifying that the Federated EMNIST\n",
        "dataset has interesting variation across users."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Dcsg0IUb0o77"
      },
      "outputs": [],
      "source": [
        "def display_raw_emnist(data, grid_width=25):\n",
        "  \"\"\"A helper function to display images from Fed EMNIST datasets.\"\"\"\n",
        "  # List of numpy images\n",
        "  img_data = np.array([x['pixels'].numpy() for x in data])\n",
        "  img_data = np.reshape(img_data, (-1, 28, 28, 1))\n",
        "  num_rows = int(np.ceil(len(img_data) / grid_width))\n",
        "  \n",
        "  # Pad to rectangular since tfgan.eval.python_image_grid\n",
        "  # expects this.\n",
        "  needed_images = num_rows * grid_width\n",
        "  tmp = np.zeros((needed_images, 28, 28, 1))\n",
        "  s = img_data.shape\n",
        "  tmp[:s[0], :s[1], :s[2]] = img_data\n",
        "  img_data = tmp\n",
        "  \n",
        "  img_grid = tfgan.eval.python_image_grid(\n",
        "      img_data, grid_shape=(num_rows, grid_width))\n",
        "\n",
        "  h = 20\n",
        "  w = h * (grid_width / num_rows)\n",
        "  plt.figure(figsize=(h, w))\n",
        "  plt.axis('off')\n",
        "  plt.imshow(np.squeeze(img_grid), cmap='binary')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "xCfC_ro25y31"
      },
      "source": [
        "Display the data from clients with accuracy below a threshold."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "e6RZHtMjzDOL"
      },
      "outputs": [],
      "source": [
        "accuracy_by_client_id = {}\n",
        "THRESHOLD = 0.82\n",
        "\n",
        "for i, client_id in enumerate(emnist_train.client_ids):\n",
        "  raw_data = emnist_train.create_tf_dataset_for_client(client_id)\n",
        "  num_examples = sum([1 for _ in raw_data])\n",
        "  loss, accuracy = model.evaluate(preprocess_test(raw_data), verbose=0)\n",
        "  accuracy_by_client_id[client_id] = accuracy\n",
        "  if accuracy \u003c THRESHOLD:\n",
        "    print('client {} ({}) with {:3d} examples has accuracy {:6.2f}%'.format(\n",
        "        i, client_id, num_examples, 100 * accuracy))\n",
        "    display_raw_emnist(raw_data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "EnwfiozoYXzP"
      },
      "source": [
        "## Accuracy vs client_id\n",
        "Now, let's plot accuracy versus the (sorted) `client_id`s. We are interested in the general trend, so we use a moving average over clients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Qp-E4z6h_sJq"
      },
      "outputs": [],
      "source": [
        "client_ids = sorted(list(emnist_train.client_ids))\n",
        "y = [accuracy_by_client_id[client_id] for client_id in client_ids]\n",
        "y = pd.Series(y).rolling(window=50).mean(center=True)\n",
        "plt.figure(figsize=(15, 3))\n",
        "plt.plot(range(len(y)), y)\n",
        "plt.title('Rolling mean accuracy vs client_id')\n",
        "plt.xlabel('client_id')\n",
        "plt.ylabel('Accuracy')\n",
        "plt.ylim(0.9, 1.0)\n",
        "s1 = client_ids.index('f2100_97')\n",
        "s2 = client_ids.index('f3100_44')\n",
        "x_loc = [500, 1000, 1500, s1, s2, 3000]\n",
        "plt.xticks(x_loc, [str(client_ids[x]) for x in x_loc])\n",
        "plt.vlines([s1, s2], 0.9, 1.0)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5TBfxMytU7iC"
      },
      "source": [
        "There appears to be a correlation between the (sorted) `client_id`s and accuracy. This is likely due to the client_ids indicating the source; see Table 2 in the [User's Guide](https://s3.amazonaws.com/nist-srd/SD19/1stEditionUserGuide.pdf) for the NIST Special Database 19. Writers `f0000` - `f2099` were Census Bureau field personal, `f2100` - `f3099`  were high school students, and `f3100` - `f4099` were Census Bureau employees in Maryland.\n",
        "\n",
        "This finding implies that for centralized baseline training (as we did above), sufficient shuffling of the data is important (see also the TODO(b/135021147) above to improve this). \n",
        "\n",
        "For federated training, randomly sampling users is important; alternatively, this data could be used to simulate three different \"blocks\" of users to test the behavior of [Semi-Cyclic Stochastic Gradient Descent](https://arxiv.org/abs/1904.10120) as well as the mitigations suggested in the linked paper; non-federated experiments can be found [here](https://github.com/tensorflow/federated/tree/master/tensorflow_federated/python/research/semi_cyclic_sgd)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "v-ZmqrkufEz6"
      },
      "source": [
        "# Replicating the baseline with `tff.learning`\n",
        "\n",
        "Here we show how to use `tff.learning` to replicate the non-federated\n",
        "baseline. However, critically *the training is still essentially\n",
        "non-federated* --- that is, this is a santiy check, not a proper simulation of\n",
        "federated learning.\n",
        "\n",
        "The approach is based on the fact that if each \"client\" has IID shuffled data\n",
        "from a centralized training set, and we use the `FederatedAveraging` algorithm\n",
        "with one client per round (so there is no actual averaging), then this is\n",
        "algorithmically equivalent to running SGD centrally.\n"
      ]
    },
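    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see why, here is a sketch of one round in notation that is ours rather than\n",
        "the notebook's: $w_t$ is the server model, $S_t$ the set of clients sampled in\n",
        "round $t$, $n_k$ the number of examples on client $k$, $\\eta$ the client\n",
        "learning rate, and $b_j$ the $j$-th of $B$ local batches. We assume the server\n",
        "applies the weighted-average client update unchanged (a server step size of 1).\n",
        "Each client starts from $w^k_{t,0} = w_t$ and runs local SGD:\n",
        "\n",
        "$$w^k_{t,j+1} = w^k_{t,j} - \\eta \\nabla \\ell(w^k_{t,j}; b_j), \\qquad j = 0, \\dots, B-1.$$\n",
        "\n",
        "The server then averages the client deltas, weighted by example count:\n",
        "\n",
        "$$w_{t+1} = w_t + \\sum_{k \\in S_t} \\frac{n_k}{\\sum_{k' \\in S_t} n_{k'}} \\left(w^k_{t,B} - w_t\\right).$$\n",
        "\n",
        "With a single client, $S_t = \\{k\\}$, the weight is 1 and the update collapses to\n",
        "\n",
        "$$w_{t+1} = w_t + \\left(w^k_{t,B} - w_t\\right) = w^k_{t,B},$$\n",
        "\n",
        "so each round is just $B$ consecutive SGD steps on the shared model."
      ]
    },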
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "567-s6ogbm-p"
      },
      "source": [
        "\n",
        "## Construct the `federated_averaging_process`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "1IoZ-T3g7kwT"
      },
      "outputs": [],
      "source": [
        "dummy_batch = tf.nest.map_structure(lambda x: x.numpy(),\n",
        "                                    next(iter(flat_train_data.take(1))))\n",
        "\n",
        "\n",
        "def create_tff_model():\n",
        "  keras_model = build_cnn()\n",
        "  return tff.learning.from_compiled_keras_model(\n",
        "      build_cnn(), dummy_batch=dummy_batch)\n",
        "\n",
        "\n",
        "fed_avg_process = tff.learning.build_federated_averaging_process(\n",
        "    model_fn=create_tff_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NFom89Rg6X2I"
      },
      "source": [
        "## Helper for selecting datasets for each \"round\"\n",
        "We work around a dataset issue to construct a sequence of Datasets each containing\n",
        "`BATCHES_PER_ROUND` batches of size `BATCH_SIZE` from the flat shuffled\n",
        "training data.\n",
        "\n",
        "TODO(b/134945216): Once supported, use `tf.data.Dataset.window()` instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "QcXJcjuP6VXp"
      },
      "outputs": [],
      "source": [
        "# Dataset of \"big\" batches to work around window issue (b/134945216)\n",
        "tff_train_data = preprocess_train(\n",
        "    emnist_train.create_tf_dataset_from_all_clients(),\n",
        "    batch_size=BATCH_SIZE * BATCHES_PER_ROUND)\n",
        "\n",
        "tff_train_data_iter = iter(tff_train_data)\n",
        "\n",
        "\n",
        "def next_client_dataset():\n",
        "  # Grab the next \"big\" batch, create a dataset, and split into regular batches.\n",
        "  client_data = tf.data.Dataset.from_tensor_slices(next(tff_train_data_iter))\n",
        "  return client_data.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "wDojCQ2O6j_W"
      },
      "source": [
        "## Training and evaluation\n",
        "\n",
        "Now we are ready to do some training. We do 10 rounds of 1000 batches per round,\n",
        "but the split between rounds doesn't really matter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "--5NwJKrqX0N"
      },
      "outputs": [],
      "source": [
        "state = fed_avg_process.initialize()\n",
        "\n",
        "print('Running Federated Averaging')\n",
        "start_time = time.time()\n",
        "for i in range(NUM_ROUNDS):\n",
        "  # Run one round of FederatedAveraging, on a single client.\n",
        "  round_start_time = time.time()\n",
        "  state, metrics = fed_avg_process.next(state, [next_client_dataset()])\n",
        "  finish_time = time.time()\n",
        "  print('Round {:3d} took {:6.2f} seconds (total {:4.0f} seconds). '\n",
        "        'Training metrics: {}'.format(i, finish_time - round_start_time,\n",
        "                                      finish_time - start_time, metrics))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "PuJhFT8L7VPj"
      },
      "outputs": [],
      "source": [
        "print('Final model evaluation on test data')\n",
        "keras_model = build_cnn()\n",
        "tff.learning.assign_weights_to_keras_model(keras_model, state.model)\n",
        "_ = keras_model.evaluate(flat_test_data)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "non_federated_emnist_baseline.ipynb",
      "provenance": [],
      "toc_visible": true,
      "version": "0.3.2"
    },
    "kernelspec": {
      "display_name": "Python 2",
      "name": "python2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
